/*
 * Copyright (C) 2016 Edward Raff
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package com.jstarcraft.ai.jsat.datatransform.visualization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicLong;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.datatransform.DataTransform;
import com.jstarcraft.ai.jsat.distributions.Uniform;
import com.jstarcraft.ai.jsat.linear.DenseVector;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.linear.distancemetrics.DistanceMetric;
import com.jstarcraft.ai.jsat.linear.distancemetrics.EuclideanDistance;
import com.jstarcraft.ai.jsat.utils.SystemInfo;
import com.jstarcraft.ai.jsat.utils.concurrent.ParallelUtils;
import com.jstarcraft.ai.jsat.utils.random.RandomUtil;

/**
 * LargeViz is an algorithm for creating low dimensional embeddings for
 * visualization. It is meant to be faster and better quality than {@link TSNE
 * t-SNE} without requiring any parameter tuning to get good results. LargeViz
 * is related to t-SNE in how the neighbor graph is constructed, and the
 * {@link #setPerplexity(double) perplexity} parameter for LargeViz has the same
 * meaning and impact as the perplexity parameter in t-SNE.<br>
 * <br>
 * NOTE: The original LargeViz paper includes a faster scheme for approximately
 * constructing the nearest neighbor graph. That scheme is not yet implemented
 * here; its absence affects only running time and has no negative impact on
 * the quality of the result. <br>
 * See: Tang, J., Liu, J., Zhang, M., & Mei, Q. (2016). Visualizing Large-scale
 * and High-dimensional Data. In Proceedings of the 25th International
 * Conference on World Wide Web (pp. 287–297). Republic and Canton of Geneva,
 * Switzerland: International World Wide Web Conferences Steering Committee.
 * doi:10.1145/2872427.2883041
 * 
 * @author Edward Raff
 */
public class LargeViz implements VisualizationTransform {
    private DistanceMetric dm_source = new EuclideanDistance();
    private DistanceMetric dm_embed = new EuclideanDistance();
    private double perplexity = 50;
    private int dt = 2;

    /**
     * This is the number of negative samples to take for each vertex <br>
     * "number of negative samples is set as 5"
     */
    private int M = 5;

    /**
     * "γ is set as 7 by default"
     */
    private double gamma = 7;

    /**
     * Sets the target perplexity of the gaussian used over each data point. The
     * perplexity can be thought of as a quasi desired number of nearest neighbors
     * to be considered, but is adapted based on the distribution of the data.
     * Increasing the perplexity can increase the amount of time it takes to get an
     * embedding. Using a value in the range of [5, 100] is recommended.
     *
     * @param perplexity the quasi number of neighbors to consider for each data
     *                   point
     */
    public void setPerplexity(double perplexity) {
        if (perplexity <= 0 || Double.isNaN(perplexity) || Double.isInfinite(perplexity))
            throw new IllegalArgumentException("perplexity must be positive, not " + perplexity);
        this.perplexity = perplexity;
    }

    /**
     * 
     * @return the target perplexity to use for each data point
     */
    public double getPerplexity() {
        return perplexity;
    }

    /**
     * Sets the distance metric to use for the original space. This will determine
     * the target nearest neighbors to keep close to each other in the embedding
     * space
     * 
     * @param dm the distance metric to use
     */
    public void setDistanceMetricSource(DistanceMetric dm) {
        this.dm_source = dm;
    }

    /**
     * Sets the distance metric to use for the embedded space. This will determine
     * the actual nearest neighbors as the occur in the embedded space.
     * 
     * @param dm the distance metric to use
     */
    public void setDistanceMetricEmbedding(DistanceMetric dm) {
        this.dm_embed = dm;
    }

    /**
     * Sets the number of negative neighbor samples to obtain for each data point.
     * The default recommended value is 5.
     *
     * @param M the number of negative samples to use for each update
     */
    public void setNegativeSamples(int M) {
        if (M < 1)
            throw new IllegalArgumentException("Number of negative samples must be positive, not " + M);
        this.M = M;
    }

    /**
     * 
     * @return the number of negative samples to use for each update
     */
    public int getNegativeSamples() {
        return M;
    }

    /**
     * Gamma controls the negative weight assigned to negative edges in the
     * optimization problem. Large values will place a higher emphasis on separating
     * non-neighbors in the embedded space. The default recommend value is 7.
     *
     * @param gamma the weight for negative edge samples
     */
    public void setGamma(double gamma) {
        if (Double.isInfinite(gamma) || Double.isNaN(gamma) || gamma <= 0)
            throw new IllegalArgumentException("Gamma must be positive, not " + gamma);
        this.gamma = gamma;
    }

    /**
     * 
     * @return the weight for negative edge samples
     */
    public double getGamma() {
        return gamma;
    }

    @Override
    public int getTargetDimension() {
        return dt;
    }

    @Override
    public boolean setTargetDimension(int target) {
        if (target < 2)
            return false;
        dt = target;
        return true;
    }

    @Override
    public <Type extends DataSet> Type transform(DataSet<Type> d, boolean parallel) {
        Random rand = RandomUtil.getRandom();
        final ThreadLocal<Random> local_rand = ThreadLocal.withInitial(RandomUtil::getRandom);

        final int N = d.size();
        // If perp set too big, the search size would be larger than the dataset size.
        // So min to N
        /**
         * form sec 4.1: "we compute the sparse approximation by finding the floor(3u)
         * nearest neighbors of each of the N input objects (recall that u is the
         * perplexity of the conditional distributions)"
         */
        final int knn = (int) Math.min(Math.floor(3 * perplexity), N - 1);

        /**
         * P_ij does not change at this point, so lets compute these values only once
         * please! j index matches up to the value stored in nearMe. Would be W_ij in
         * notation of LargeViz paper, but P_ij form TSNE paper
         */
        final double[][] nearMePij = new double[N][knn];

        /**
         * Each row is the set of 3*u indices returned by the NN search
         */
        final int[][] nearMe = new int[N][knn];

        TSNE.computeP(d, parallel, rand, knn, nearMe, nearMePij, dm_source, perplexity);

        final double[][] nearMeSample = new double[N][knn];

        /**
         * Array of the sample weights used to perform the negative sampling.
         * 
         * Initial value is out-degree defined in LINE paper, section 4.1.2.
         */
        final double[] negSampleWeight = new double[N];

        double negSum = 0;
        for (int i = 0; i < N; i++) {
            double sum = DenseVector.toDenseVec(nearMePij[i]).sum();
            sum += nearMePij[i].length * Double.MIN_VALUE;
            negSampleWeight[i] = sum;

            nearMeSample[i][0] = nearMePij[i][0];
            for (int j = 1; j < knn; j++)// make cumulative
                nearMeSample[i][j] = Math.ulp(nearMePij[i][j]) + nearMePij[i][j] + nearMeSample[i][j - 1];
            for (int j = 1; j < knn; j++)// normalize
                nearMeSample[i][j] /= sum;
            negSampleWeight[i] = Math.pow(negSampleWeight[i], 0.75);
            negSum += negSampleWeight[i];
            if (i > 0)
                negSampleWeight[i] += negSampleWeight[i - 1];
        }
        // normalize to [0, 1] range
        for (int i = 0; i < N; i++)
            negSampleWeight[i] /= negSum;

        final List<Vec> embeded = new ArrayList<>();
        Uniform initDistribution = new Uniform(-0.00005 / dt, 0.00005 / dt);
        for (int i = 0; i < N; i++)
            embeded.add(initDistribution.sampleVec(dt, rand));

        /**
         * Number of threads to use. Paper suggests asynch updates and just ignore
         * unsafe alters b/c diff should be minor. Adding some extra logic so that we
         * have at least a good handful of points per thread to avoid excessive edits on
         * small datasets.
         */
        final int threads_to_use = Math.max(Math.min(N / (200 * M), SystemInfo.LogicalCores), 1);

        final CountDownLatch latch = new CountDownLatch(threads_to_use);

        /*
         * Objective is w*(log(1/(1+g(x)^2)) + y log(1−1/(1+g(x)^2 ))) where g(x) is the
         * euclidean distance adn G(x) is g(x)^2
         * 
         * d/x of ||x-y|| = (x-y)/||x-y|| d/y of ||x-y|| = -(x-y)/||x-y||
         * 
         * left hand side derivative of log(1/(1+g(x))) = = -(2 g(x) g'(x))/(g(x)^2+1) =
         * -(2 ||x-y|| (x-y)/||x-y||)/(||x-y||^2+1) = -(2 (x-y))/(||x-y||^2+1) for d/y =
         * -(2 (y-x))/(||x-y||^2+1)
         *
         * Right hand side portion derivative of z* log(1-1/(1+g(x)^2)) = (2 z
         * g'(x))/(g(x) (g(x)^2+1)) = (2 z (x-y))/(||x-y||^2 (||x-y||^2+1)) or for d/y =
         * (2 z (y-x))/(||x-y||^2 (||x-y||^2+1))
         * 
         * NOTE: My derivative dosn't work. But adding an extra multiplication by
         * ||x-y|| seems to fix everything? Want to come back and figure this out
         * better.
         */

        final double eta_0 = 1.0;
        final long iterations = 1000L * N;
        final ThreadLocal<Vec> local_grad_i = ThreadLocal.withInitial(() -> new DenseVector(dt));
        final ThreadLocal<Vec> local_grad_j = ThreadLocal.withInitial(() -> new DenseVector(dt));
        final ThreadLocal<Vec> local_grad_k = ThreadLocal.withInitial(() -> new DenseVector(dt));

        AtomicLong curIteration = new AtomicLong();
        ParallelUtils.run(parallel, N, (start, end) -> {
            Random l_rand = local_rand.get();
            // b/c indicies are selected at random everyone can use same iterator order
            // more important is to make sure the range length is the same so that
            // eta has the same range and effect in aggregate
            // To avoid issues with large datests, we want 1000 * N iterations
            // so do an iteration of our N/P 1000 times to get the correct amount
            for (int moreTimes = 0; moreTimes < 1000; moreTimes++)
                for (int iter = start; iter < end; iter++) {
                    double eta = eta_0 * (1 - curIteration.getAndIncrement() / (double) iterations);
                    eta = Math.max(eta, 0.0001);

                    int i = l_rand.nextInt(N);
                    // sample neighbor weighted by distance
                    int j = Arrays.binarySearch(nearMeSample[i], l_rand.nextDouble());
                    if (j < 0)
                        j = -(j) - 1;
                    if (j >= knn)/// oops. Can be hard to sample / happen with lots of near by near 0 dists
                    {
                        // lets fall back to picking someone at random
                        j = l_rand.nextInt(knn);
                    }
                    j = nearMe[i][j];

                    Vec y_i = embeded.get(i);
                    Vec y_j = embeded.get(j);
                    // right hand side update for the postive sample
                    final double dist_ij = dm_embed.dist(i, j, embeded, null);
                    final double dist_ij_sqrd = dist_ij * dist_ij;
                    if (dist_ij <= 0)
                        continue;// how did that happen?

                    Vec grad_i = local_grad_i.get();
                    Vec grad_j = local_grad_j.get();
                    Vec grad_k = local_grad_k.get();
                    y_i.copyTo(grad_j);
                    grad_j.mutableSubtract(y_j);
                    grad_j.mutableMultiply(-2 * dist_ij / (dist_ij_sqrd + 1));

                    grad_j.copyTo(grad_i);

                    // negative sampling time
                    for (int k = 0; k < M; k++) {
                        int jk = -1;
                        do {
                            jk = Arrays.binarySearch(negSampleWeight, l_rand.nextDouble());
                            if (jk < 0)
                                jk = -(jk) - 1;

                            if (jk == i || jk == j)
                                jk = -1;

                            // code to reject neighbors for sampling if too close
                            // Not sure if this code helps or hurts... not mentioned in paper
                            for (int search = 0; search < nearMe[i].length; search++)
                                if (nearMe[i][search] == jk && nearMeSample[i][search] < 0.98) {
                                    jk = -1;// too close to me!
                                    break;
                                }
                        } while (jk < 0);
                        // (2 z (y-x))/(||x-y||^2 (||x-y||^2+1))

                        Vec y_k = embeded.get(jk);
                        final double dist_ik = dm_embed.dist(i, jk, embeded, null);// dist(y_i, y_k);
                        final double dist_ik_sqrd = dist_ik * dist_ik;
                        if (dist_ik < 1e-12)
                            continue;

                        y_i.copyTo(grad_k);
                        grad_k.mutableSubtract(y_k);
                        grad_k.mutableMultiply(2 * gamma / (dist_ik * (dist_ik_sqrd + 1)));

                        grad_i.mutableAdd(grad_k);

                        y_k.mutableSubtract(eta, grad_k);

                    }

                    y_i.mutableAdd(eta, grad_i);
                    y_j.mutableAdd(-eta, grad_j);
                }
        });

        DataSet<Type> toRet = d.shallowClone();

        final IdentityHashMap<DataPoint, Integer> indexMap = new IdentityHashMap<>(N);
        for (int i = 0; i < N; i++)
            indexMap.put(d.getDataPoint(i), i);

        toRet.applyTransform(new DataTransform() {
            @Override
            public DataPoint transform(DataPoint dp) {
                int i = indexMap.get(dp);

                return new DataPoint(embeded.get(i), dp.getCategoricalValues(), dp.getCategoricalData());
            }

            @Override
            public void fit(DataSet data) {

            }

            @Override
            public DataTransform clone() {
                return this;
            }
        });

        return (Type) toRet;
    }

}